# -- coding=utf-8 --
import os
import time

import requests
import base64
import json
import logging as log
import cv2
from PIL import Image
import numpy as np

log.basicConfig(level=log.INFO)

# Inference service endpoint used by post() / run_* helpers by default.
# NOTE(review): the URL is plain http, so the CA bundle that post() passes as
# `verify` is not actually used for this endpoint.
SERVER_URL = 'http://123.60.231.101:30090/infer'
# Default TaskName for requests built by run_one_model / run_unet.
TASK_NAME = 'assm1-2'
# Model alias -> server-side project name, looked up by get_project_name().
model_pro_dic = dict([
    ('ssd', 'project_ssd'),
    ('ssd+resnet', 'project_ssd_resnet'),
    ('yolov4', 'project_yolov4'),
    ('yolov4+resnet', 'project_yolov4_resnet'),
    ('unet', 'project_unet'),
    ('unet++', 'project_unet++'),
    ('crnn', 'project_crnn'),
    ('ctpn_crnn', 'project_ctpn_crnn'),
    ('yolov4+resnet_ch', 'project_yolov4_resnet_ch'),
])
# Directory containing ca.pem, which post() passes to requests as `verify`.
CA_DIR = './cert_files/'


def build_req(model_name,
              task_name,
              with_reg=False,
              pic_path='',
              project_name="",
              **ext_params):
    """
    Build the inference request dict for one image and save it to disk.

    :param model_name: model alias; selects the default image
        ``reqs/<model_name>.jpg`` and, when it contains "unet", the UNet
        MatchParams instead of the classification ones.
    :param task_name: value of the ``TaskName`` field.
    :param with_reg: True -> ``TaskType`` "WithReg" and the fixed ROIs;
        False -> "WithoutReg" and no ROIs.
    :param pic_path: image to send; empty means the per-model default.
    :param project_name: value of ``ProjectName``; also used to name the
        saved request JSON.
    :param ext_params: extra top-level fields that override the defaults.
    :return: the request dict (also written via save_req_json).
    :raises FileNotFoundError: if the image file does not exist.
    """
    image_path = pic_path or 'reqs/' + model_name + ".jpg"
    try:
        with open(image_path, 'rb') as img_file:
            encoded = base64.b64encode(img_file.read())
    except FileNotFoundError as e:
        log.error(f"File: ({image_path}) Not Found!. error:{e}")
        raise

    if with_reg:
        task_type, rois = "WithReg", get_rois()
    else:
        task_type, rois = "WithoutReg", []

    if "unet" in model_name:
        match_params = get_unet_match_params()
    else:
        match_params = get_cls_match_params()

    # Key order is kept stable so the saved / posted JSON is unchanged.
    request = {
        "TaskName": task_name,
        "ProjectName": project_name,
        "TaskType": task_type,
        "MachineId": "bd11",
        "ProductCode": "p223",
        "CameraID": "c12",
        "TableName": "AI_Test",
        "ROIs": rois,
        "CodeType": "jpg",
        "Patterns": get_patterns(),
        "Files": [{
            "FileName": image_path,
            "ImageData": encoded.decode()
        }],
        "ExtraParams": {},
        "MatchParams": match_params,
        "RegisterMatrix": [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
    }
    request.update(ext_params)
    save_req_json(project_name, request)
    return request


def get_rois():
    """Return a fresh list holding the single fixed ``include`` ROI."""
    include_roi = dict(
        Height=2419.5469879518073,
        Type="include",
        Width=3210.6795180722884,
        X=843.8746987951808,
        Y=494.45783132530096,
    )
    return [include_roi]


def get_cls_match_params():
    """
    Return MatchParams for detection/classification models.

    A global threshold set plus a "Screw"-specific entry with the same values
    (confidence 0.5, overlap 0.3, IOU metric). Fresh dicts on every call.
    """
    def thresholds():
        return {"conf_thresh": 0.5, "overlap_thresh": 0.3, "overlap_metric": "IOU"}

    screw_thresholds = {"label": "Screw", **thresholds()}
    return {
        "global_thresh": thresholds(),
        "label_spec_thresh": [screw_thresholds],
    }


def get_unet_match_params():
    """Return MatchParams for UNet models: overlap threshold 0.5 with IOU."""
    return dict(overlap_thresh=0.5, overlap_metric="IOU")


def get_patterns():
    """
    Return the seven expected "Screw" patterns sent with every request.

    Each pattern is a box (Height/Width/X/Y) plus identical matching settings
    (one match, IOU overlap >= 0.5) and empty reserved fields. Fresh dicts are
    built on every call.
    """
    # (Height, Width, X, Y) of each expected screw, in request order.
    screw_boxes = (
        (150, 139, 967, 2341),
        (156, 144, 1004, 814),
        (155, 150, 3578, 869),
        (148, 135, 3463, 2402),
        (164, 160, 1978, 2107),
        (147, 144, 1138, 1104),
        (85, 83, 1766, 1510),
    )
    patterns = []
    for height, width, x, y in screw_boxes:
        pattern = {
            'Height': height,
            'Label': 'Screw',
            'Width': width,
            'X': x,
            'Y': y,
            'detail_label': '',
            'num_match': 1,
            'overlap_metric': 'IOU',
            'overlap_thresh': 0.5,
        }
        for slot in range(1, 7):
            pattern['reserved%d' % slot] = ''
        patterns.append(pattern)
    return patterns


def get_project_name(model_name):
    """
    Map a model alias to its server-side project name.

    :param model_name: key looked up in ``model_pro_dic``.
    :return: the mapped project name, or ``model_name`` itself when there is
        no (truthy) mapping. Both outcomes are logged at INFO.
    """
    mapped = model_pro_dic.get(model_name)
    if mapped:
        log.info(f'{model_name} => {mapped}. find project by model')
        return mapped
    log.info(
        f'{model_name} => None . model not found, using model_name as project_name'
    )
    return model_name


def save_rsp_json(model_name, rsp, parent_dir='rsps/', taks_type=''):
    """
    Write a service response to ``<parent_dir>/<model_name>[-<taks_type>]-rsp.json``.

    :param model_name: used as the file-name prefix.
    :param rsp: response to serialise; objects that are not JSON-native are
        dumped via their ``__dict__``.
    :param parent_dir: output directory, created if missing; works with or
        without a trailing slash, and '' means the current directory.
    :param taks_type: optional task type appended as ``-<taks_type>``; any
        falsy value (including None) means no suffix.
    :return: None
    """
    rsp_str = json.dumps(rsp, default=lambda obj: obj.__dict__)
    suffix = '-' + str(taks_type) if taks_type else ''
    if parent_dir:
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(parent_dir, exist_ok=True)
    file_path = os.path.join(parent_dir, model_name + suffix + '-rsp.json')
    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(rsp_str)


def save_req_json(model_name, mx_aoi_req, parent_dir='reqs/req_json/'):
    """
    Write a request dict to ``<parent_dir>/<model_name>-<TaskType>-req.json``.

    :param model_name: used as the file-name prefix.
    :param mx_aoi_req: request dict; its ``TaskType`` value becomes part of
        the file name ("None" when the key is missing).
    :param parent_dir: output directory, created if missing; works with or
        without a trailing slash, and '' means the current directory.
    :return: None
    """
    req_str = json.dumps(mx_aoi_req, default=lambda obj: obj.__dict__)
    if parent_dir:
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(parent_dir, exist_ok=True)
    file_name = f"{model_name}-{mx_aoi_req.get('TaskType')}-req.json"
    with open(os.path.join(parent_dir, file_name), 'w', encoding='utf-8') as f:
        f.write(req_str)


def post(req, service_url=SERVER_URL):
    """
    POST a request to the inference service.

    :param req: request dict, sent as the JSON body.
    :param service_url: service endpoint, defaults to SERVER_URL.
    :return: ``(http_status, rsp)`` where ``rsp`` is the decoded JSON body,
        or ``{}`` when the body is not valid JSON. When the request itself
        fails (connection error, TLS / CA-file problem, ...) the status is
        ``-1`` and ``rsp`` is ``{}``.
    """
    ca_file = os.path.join(CA_DIR, 'ca.pem')
    # For mutual TLS, also pass cert=(os.path.join(CA_DIR, 'https.pem'),
    # os.path.join(CA_DIR, 'https-key.pem')).
    headers = {
        'Content-Type': 'application/json',
        "uuid": "824D1630-5254-81E7-2702-7C11CB4D1CE4",
        "ip": "169.254.45.152",
        "vision-software": "opt",
        "scene-name": "det"}
    try:
        r = requests.post(service_url,
                          verify=ca_file,
                          json=req,
                          headers=headers)
    except (requests.exceptions.RequestException, OSError):
        # There is no response object here, so report without touching it.
        log.error(f'request to {service_url} failed', exc_info=True)
        return -1, {}

    try:
        rsp = r.json()
    except ValueError:
        # requests' JSON decode errors subclass ValueError in all versions.
        log.error(
            f'error: rsp is not json, httpcode={r.status_code}, rsp content:{r.text}'
        )
        rsp = {}
    return r.status_code, rsp


def run_one_model(model_name,
                  task_name=TASK_NAME,
                  need_reg=False,
                  service_addr=SERVER_URL,
                  img_path='',
                  project_name='',
                  drawing_box_needed=True,
                  **ext_params):
    """
    Run a single model: build the request, post it, save and optionally draw the result.

    :param model_name: model alias (mapped to a project via get_project_name).
    :param task_name: task name.
    :param need_reg: whether to request registration; defaults to False.
    :param service_addr: service endpoint; defaults to SERVER_URL.
    :param img_path: image to send; empty means build_req's per-model default.
    :param project_name: project name; empty means resolve from model_name.
    :param drawing_box_needed: whether to draw the returned result on the image.
    :param ext_params: extra top-level request fields (see build_req).
    :return: ``(http_code, rsp)``; ``(-1, None)`` if the request could not be built.
    """
    pro_name = project_name or get_project_name(model_name)
    try:
        req = build_req(model_name,
                        task_name,
                        with_reg=need_reg,
                        pic_path=img_path,
                        project_name=pro_name,
                        **ext_params)
    except Exception:
        log.error('req params is wrong!', exc_info=True)
        return -1, None

    http_code, res = post(req, service_url=service_addr)
    save_rsp_json(model_name, res, taks_type=req.get('TaskType'))
    log.info("%s 返回: %d \n %s", model_name, http_code, res)
    if drawing_box_needed:
        # The response is keyed by the FileName actually sent, which is the
        # per-model default image when img_path is empty.
        files = req.get('Files') or [{}]
        sent_path = files[0].get('FileName') or img_path
        if res:
            draw_boxes(res, sent_path)
        else:
            log.error(f'{model_name}: empty response, nothing to draw')
    return http_code, res


def _build_template_req(model_name, task_name, img_path, project_name,
                        camera_id, blob_mask, blobs):
    """
    Build the SetTemplate request from a template image's result and save it under reqs/tpl/.

    :param blob_mask: ``BlobMask`` of the template image's WithoutReg result.
    :param blobs: ``Blobs`` of the same result; every blob ID is selected.
    :return: request dict without ``Files``, ``Patterns`` and ``RegisterMatrix``.
    """
    tpl_params = {
        "TaskType": "SetTemplate",
        "MachineId": "bd11",
        "ProductCode": "p223",
        "CameraID": camera_id,
        "Template": {
            "CodeType": "jpg",
            "BlobMask": blob_mask,
            "Blobs": blobs,
            "SelectBlobs": [x["ID"] for x in blobs]
        }
    }
    tpl_req = build_req(model_name,
                        task_name=task_name,
                        pic_path=img_path,
                        project_name=project_name,
                        **tpl_params)
    # The template is carried by BlobMask/Blobs, so per-image fields are dropped.
    for key in ('Files', 'Patterns', 'RegisterMatrix'):
        tpl_req.pop(key, None)
    save_req_json(project_name, tpl_req, parent_dir='reqs/tpl/')
    return tpl_req


def _set_template_on_all_pods(model_name, tpl_req, service_addr, max_attempts):
    """
    Re-send SetTemplate until every pod replica has acknowledged it.

    The service runs several pod replicas, so one call may reach only one of
    them; each response reports its ``PodId`` and the replica count ``PodsNum``.

    :raises Exception: on a non-200 reply, or if not all pods were reached
        within ``max_attempts`` calls.
    """
    pods = set()
    pod_total = None  # unknown until a response carries a parsable PodsNum
    for _ in range(max_attempts):
        http_code, res = post(tpl_req, service_url=service_addr)
        if http_code != 200:
            raise Exception('{0} set templates error. rsp httpcode={1}'.format(
                model_name, http_code))
        pods.add(res.get('PodId'))
        try:
            pod_total = int(res.get('PodsNum', 0))
        except (TypeError, ValueError):
            log.info(f"get PodsNum error,PodsNum:{res.get('PodsNum')}")
        log.info(
            f"pod({res.get('PodId')}) succeed to set templates for {model_name},res={res}"
        )
        if pod_total is not None and len(pods) >= pod_total:
            return
    raise Exception(
        '{0} set templates error: {1}/{2} pods acknowledged after {3} attempts'
        .format(model_name, len(pods), pod_total, max_attempts))


def run_unet(model_name,
             task_name=TASK_NAME,
             need_reg=False,
             service_addr=SERVER_URL,
             img_path='',
             project_name='',
             drawing_box_needed=True,
             cameraId='c24',
             max_tpl_attempts=100):
    """
    Run a UNet-style model, optionally with template registration.

    Without registration this is one WithoutReg inference. With registration:
    run WithoutReg on the image to get its BlobMask, push that as a
    SetTemplate to every pod, wait 5 s, then run the WithReg inference.

    :param model_name: model alias.
    :param task_name: task name used in every request.
    :param need_reg: whether to run the registration flow.
    :param service_addr: service endpoint.
    :param img_path: image to use; empty means ``reqs/<model_name>.jpg``.
    :param project_name: project name; empty means resolve via get_project_name.
    :param drawing_box_needed: draw the final WithReg result.
    :param cameraId: CameraID sent with every request.
    :param max_tpl_attempts: maximum SetTemplate calls (>= 1); prevents an
        endless loop when PodsNum is unusable or the same pod keeps answering.
    :return: ``(http_code, rsp)`` of the last inference.
    :raises Exception: if the template BlobMask cannot be obtained or the
        template cannot be set on all pods.
    """
    # Resolve defaults once so every request and the response lookup agree.
    pro_name = project_name or get_project_name(model_name)
    image_path = img_path or 'reqs/' + model_name + ".jpg"

    http_code, res_without_reg = run_one_model(model_name,
                                               task_name=task_name,
                                               need_reg=False,
                                               service_addr=service_addr,
                                               img_path=image_path,
                                               project_name=pro_name,
                                               drawing_box_needed=False,
                                               CameraID=cameraId)
    if not need_reg:
        # Without registration the first result is the answer.
        return http_code, res_without_reg
    log.info(f"{model_name} is ready for set template")

    res_tpl = (res_without_reg or {}).get(image_path) or {}
    tpl_blob_mask = res_tpl.get('BlobMask')
    if not tpl_blob_mask:
        raise Exception("can not get BlobMask of templates")
    tpl_blobs = res_tpl.get('Blobs') or []

    tpl_req = _build_template_req(model_name, task_name, image_path, pro_name,
                                  cameraId, tpl_blob_mask, tpl_blobs)
    _set_template_on_all_pods(model_name, tpl_req, service_addr,
                              max_tpl_attempts)

    # Give the pods time to apply the template before the WithReg inference.
    time.sleep(5)
    return run_one_model(model_name,
                         task_name=task_name,
                         need_reg=True,
                         service_addr=service_addr,
                         img_path=image_path,
                         project_name=pro_name,
                         drawing_box_needed=drawing_box_needed,
                         CameraID=cameraId)


def run_all_models():
    """
    Run every model listed in ``model_pro_dic``, first without and then with registration.

    :return: None
    """
    for alias in model_pro_dic:
        for with_registration in (False, True):
            run_one_model(alias, need_reg=with_registration)


def draw_tag_cost_time(img, res, font=cv2.FONT_HERSHEY_SIMPLEX):
    """
    Write ``res["InferenceTime"]["All"]`` as "cost: X.XXX s" in green at (10, 30).

    :param img: image array, drawn on in place.
    :param res: per-image result dict containing ``InferenceTime``.
    :param font: OpenCV font face.
    """
    total_time = res.get("InferenceTime").get("All")
    label = "cost: {:.3f} s".format(total_time)
    cv2.putText(img, label, (10, 30), font, 1, (0, 255, 0), 4)


def draw_tag_det_cls(img, res, font=cv2.FONT_HERSHEY_SIMPLEX, **ext):
    """
    Draw detection boxes plus detection/classification labels onto ``img``.

    Each entry of ``res["Objects"]`` may carry a ``Box`` (X/Y/Width/Height),
    a ``Label`` with ``Score``, and a ``Classification`` (Label/Score).
    Labels are positioned relative to the box, so objects without a box are
    skipped. A missing score is rendered as empty text; a score of 0 is
    printed as 0.000.

    :param img: image array, drawn on in place.
    :param res: per-image result dict.
    :param font: OpenCV font face.
    :param ext: unused; accepted to match the draw_txt dispatch signature.
    """
    def score_text(score):
        return '{:.3f}'.format(score) if isinstance(score, (int, float)) else ''

    for ele in res.get("Objects") or []:
        box = ele.get('Box')
        if not box:
            continue
        # OpenCV drawing functions require integer pixel coordinates.
        point1 = (int(box['X']), int(box['Y']))
        point2 = (int(box['X'] + box['Width']), int(box['Y'] + box['Height']))
        cv2.rectangle(img, point1, point2, (0, 255, 0), 4)

        label = ele.get('Label')
        if label:
            text = "{} {}".format(label, score_text(ele.get('Score'))).rstrip()
            cv2.putText(img, text, (point1[0], point1[1] - 10), font, 1,
                        (0, 255, 0), 2)

        cls = ele.get('Classification')
        if cls:
            text = "{} {}".format(cls.get('Label'),
                                  score_text(cls.get('Score'))).rstrip()
            cv2.putText(img, text, (point1[0], point2[1] + 30), font, 1,
                        (0, 255, 190), 2)


def draw_tag_ocr(img, res, font=cv2.FONT_HERSHEY_SIMPLEX):
    """
    Write each OCR result from ``res["Texts"]`` onto ``img``, one line per text.

    Lines start at y = 60 and are 30 px apart, in red. A missing score is
    rendered as empty text; a score of 0 is printed as 0.000.

    :param img: image array, drawn on in place.
    :param res: per-image result dict.
    :param font: OpenCV font face.
    """
    line_h = 30
    for i, ocr in enumerate(res.get("Texts") or []):
        score = ocr.get('Score')
        # Formatting '' with {:.3f} raises, so only format real numbers.
        score_txt = '{:.3f}'.format(score) if isinstance(score, (int, float)) else ''
        out_txt = "OCRText: {} {}".format(ocr.get('OCRText', ''), score_txt)
        log.info(f"ocr txt:{out_txt}")
        cv2.putText(img, out_txt, (10, 30 + line_h * (i + 1)), font, 1,
                    (0, 0, 255), 2)


def _fill_first_dark_region(gray_img, save_path):
    """
    Save a mask of the first 4-connected region of pixels with value exactly 0.

    The seed is the first such pixel in row-major order. The saved image
    (PIL mode "1") is black on that region and white elsewhere; when no pixel
    is 0, an all-white mask is saved.

    :param gray_img: single-channel PIL image.
    :param save_path: where the mask is written.
    """
    dark = np.asarray(gray_img) == 0
    region = np.zeros(dark.shape, dtype=bool)
    seeds = np.argwhere(dark)  # row-major order, so seeds[0] is the first hit
    if len(seeds):
        seed_y, seed_x = seeds[0]
        # connectivity=4: only up/down/left/right neighbours join a region.
        _, labels = cv2.connectedComponents(dark.astype(np.uint8),
                                            connectivity=4)
        region = labels == labels[seed_y, seed_x]
    else:
        log.warning('no seed')
    mask = np.where(region, 0, 255).astype(np.uint8)
    Image.fromarray(mask).convert('1').save(save_path)


def _overlay_blob(image, contours_path, save_path, alpha=0.3):
    """Blend the mask at ``contours_path`` over ``image`` with weight ``alpha`` and write to ``save_path``."""
    h, w = image.shape[0:2]
    contours = cv2.imread(contours_path)
    # Resize straight to (w, h) so the shapes always match for addWeighted.
    overlayer = cv2.resize(contours, (w, h), interpolation=cv2.INTER_NEAREST)
    output = cv2.addWeighted(overlayer, alpha, image, 1 - alpha, 0)
    cv2.imwrite(save_path, output)


def draw_tag_segmentation(pic_path, save_path, res):
    """
    Render a segmentation result (``res["BlobContours"]``) over the input image.

    The base64 ``BlobContours`` image is decoded, converted to grayscale and
    saved as ``rsps/boxed/BlobContours.jpg``. The first black region of it is
    saved as a mask ``rsps/boxed/contours-<image name>``, and that mask is
    blended over ``pic_path`` into ``save_path``. Unreadable inputs are logged
    and nothing is written to ``save_path``.

    :param pic_path: path of the original image.
    :param save_path: where the blended image is written.
    :param res: per-image result dict containing ``BlobContours``.
    """
    img = cv2.imread(pic_path)
    if img is None:
        log.error(f'can not read image: {pic_path}')
        return
    blob_contours = res.get("BlobContours")
    if not blob_contours:
        log.error('BlobContours is empty, skip segmentation drawing')
        return
    raw = base64.b64decode(blob_contours)
    contours_img = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
    if contours_img is None:
        log.error('BlobContours is not a decodable image')
        return

    out_dir = 'rsps/boxed'
    os.makedirs(out_dir, exist_ok=True)
    contours_path = os.path.join(
        out_dir, 'contours-{}'.format(os.path.basename(pic_path)))
    gray_img = Image.fromarray(cv2.cvtColor(contours_img, cv2.COLOR_BGR2GRAY))
    gray_img.save(os.path.join(out_dir, 'BlobContours.jpg'))
    _fill_first_dark_region(gray_img, contours_path)
    _overlay_blob(img, contours_path, save_path)
    log.info('draw_tag_segmentation ok')


# Result field -> function that draws it; draw_boxes walks these in order.
txt_tag_mapping = dict(
    InferenceTime=draw_tag_cost_time,
    Objects=draw_tag_det_cls,
    Texts=draw_tag_ocr,
)


def draw_txt(field_name, img, res, **ext):
    """
    Dispatch to the drawer registered for ``field_name`` in ``txt_tag_mapping``.

    :param field_name: result field, e.g. "Objects".
    :param img: image array, drawn on in place.
    :param res: per-image result dict.
    :param ext: forwarded to the drawer.
    :raises KeyError: if no drawer is registered for ``field_name``.
    """
    try:
        drawer = txt_tag_mapping[field_name]
    except KeyError:
        raise KeyError(f'no drawer registered for field: {field_name}') from None
    drawer(img, res, **ext)


def draw_boxes(rsp, img_path, save_path_parent='rsps/boxed'):
    """
    Draw the service result for ``img_path`` and save it as ``<save_path_parent>/boxed-<image name>``.

    ``rsp`` maps each sent FileName to its per-image result. A segmentation
    result (``BlobContours``) is rendered first; then every field from
    ``txt_tag_mapping`` present in the result is drawn on top. An empty
    response, a missing per-image result or an unreadable image is logged and
    nothing is drawn.

    :param rsp: full service response.
    :param img_path: FileName that was sent; also the image to draw on.
    :param save_path_parent: output directory, created if missing.
    :return: None
    """
    if not rsp:
        log.error('rsp is empty!')
        return
    res = rsp.get(img_path)
    if not res:
        log.error(f'no result for {img_path} in rsp')
        return
    if not os.path.exists(img_path):
        log.error(f'img_path({img_path}) not exists!')
        return
    os.makedirs(save_path_parent, exist_ok=True)
    save_pic_path = '%s/boxed-%s' % (save_path_parent,
                                     os.path.basename(img_path))

    img = None
    if 'BlobContours' in res:
        draw_tag_segmentation(img_path, save_pic_path, res)
        img = cv2.imread(save_pic_path)
    if img is None:
        img = cv2.imread(img_path)
    if img is None:
        log.error(f'can not read image: {img_path}')
        return

    for field in txt_tag_mapping:
        if field in res:
            try:
                draw_txt(field, img, res)
            except Exception:
                # One bad field should not prevent the others from being drawn.
                log.error('error in txt_tag_mapping', exc_info=True)
    cv2.imwrite(save_pic_path, img)


if __name__ == "__main__":
    # Single-model smoke run without registration (run_one_model's default).
    # Call run_all_models() instead to exercise every entry in model_pro_dic.
    run_one_model(model_name='yolov4')
